package top.lingkang.hibernate5.transaction;

import org.hibernate.Session;
import org.hibernate.StatelessSession;
import org.noear.solon.core.aspect.Interceptor;
import org.noear.solon.core.aspect.Invocation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import top.lingkang.hibernate5.config.SessionFactoryInvocationHandler;

import javax.persistence.EntityManager;
import java.util.List;

/**
 * Transaction interceptor for the Hibernate resources bound to the current thread. When it is used
 * around database operations it should be given the highest priority (run before any other
 * interceptor) so that it wraps the whole call.
 *
 * <p>When the outermost intercepted call returns normally, every resource found in
 * {@link SessionFactoryInvocationHandler#threadLocal} has its active transaction committed and is
 * then closed. If any intercepted call (at any nesting depth) throws, all of those resources are
 * rolled back and closed, and the exception is rethrown. {@link Session}, {@link EntityManager} and
 * {@link StatelessSession} resources are handled; any other object in that list is ignored.
 *
 * <p>NOTE(review): the resource list is presumably filled as sessions are opened through
 * {@code SessionFactoryInvocationHandler}; that code is not visible here.
 *
 * @author lingkang
 * created by 2023/9/28
 * @since 1.0.0
 */
public class HibernateTranInterceptor implements Interceptor {
    private static final Logger log = LoggerFactory.getLogger(HibernateTranInterceptor.class);

    /**
     * Nesting depth of intercepted calls on the current thread: {@code 0} for the outermost call,
     * one more for each nested call, and {@code null} when no call is in progress or after a
     * failure has already rolled everything back.
     */
    public static final ThreadLocal<Integer> tranNumber = new ThreadLocal<>();

    /**
     * Runs the intercepted call inside the current thread's transaction scope.
     *
     * <p>Only the outermost call (depth {@code 0}) commits. A nested call only restores the depth
     * counter when it succeeds. Any exception, at any depth, does three things: it clears the
     * counter, rolls back and closes every resource bound to the thread, and is rethrown unchanged.
     *
     * @param inv the intercepted invocation
     * @return the invocation's result
     * @throws Throwable whatever the invocation or the final commit throws
     */
    @Override
    public Object doIntercept(Invocation inv) throws Throwable {
        Integer enclosing = tranNumber.get();
        // Keep this call's depth in a local. A failing nested call clears the shared counter, so
        // re-reading it after invoke() could yield null (unboxing NPE) if this level swallowed the
        // nested failure.
        int depth = enclosing == null ? 0 : enclosing + 1;
        tranNumber.set(depth);

        try {
            Object result = inv.invoke();
            if (depth == 0) {
                // Outermost call: end the thread-level transaction scope.
                tranNumber.remove();
                commit();
            } else {
                tranNumber.set(depth - 1);
            }
            return result;
        } catch (Throwable e) {
            // Any failure aborts the whole thread-level transaction, not just this nesting level.
            // If commit() itself failed, it has already cleared the resource list, so this is a no-op.
            tranNumber.remove();
            rollback();
            throw e;
        }
    }

    /**
     * Commits the active transaction of every resource bound to the current thread, then closes it.
     *
     * <p>Once any step fails (a commit or a close), the remaining resources are rolled back instead
     * of committed. Every resource is still closed, and the thread-local resource list is always
     * cleared. The first failure is rethrown, with later close failures attached as suppressed
     * exceptions.
     */
    private void commit() {
        List<Object> resources = SessionFactoryInvocationHandler.threadLocal.get();
        try {
            if (resources == null) {
                return;
            }
            RuntimeException failure = null;
            for (Object resource : resources) {
                if (failure == null) {
                    try {
                        endTransaction(resource, true);
                    } catch (RuntimeException e) {
                        failure = e;
                        // The failed transaction may still be active; try to undo it.
                        rollbackQuietly(resource);
                    }
                } else {
                    // Something already failed: undo the remaining work instead of committing it.
                    rollbackQuietly(resource);
                }
                try {
                    close(resource);
                } catch (RuntimeException e) {
                    if (failure == null) {
                        failure = e;
                    } else {
                        failure.addSuppressed(e);
                    }
                }
            }
            if (failure != null) {
                throw failure;
            }
        } finally {
            SessionFactoryInvocationHandler.threadLocal.remove();
        }
    }

    /**
     * Rolls back the active transaction of every resource bound to the current thread, then closes
     * it. This method is best-effort: rollback and close failures are logged rather than thrown, so
     * they cannot mask the exception that triggered the rollback. The thread-local resource list is
     * always cleared.
     */
    private void rollback() {
        List<Object> resources = SessionFactoryInvocationHandler.threadLocal.get();
        try {
            if (resources == null) {
                return;
            }
            for (Object resource : resources) {
                rollbackQuietly(resource);
                try {
                    close(resource);
                } catch (RuntimeException e) {
                    log.warn("关闭会话异常：", e);
                }
            }
        } finally {
            SessionFactoryInvocationHandler.threadLocal.remove();
        }
    }

    /** Rolls back {@code resource}'s active transaction, logging instead of throwing on failure. */
    private static void rollbackQuietly(Object resource) {
        try {
            endTransaction(resource, false);
        } catch (RuntimeException e) {
            log.warn("回滚事务异常：", e);
        }
    }

    /**
     * Commits ({@code commit == true}) or rolls back the resource's transaction if it is active.
     * {@link Session} is checked before {@link EntityManager} because a Hibernate session may itself
     * be an {@code EntityManager}.
     */
    private static void endTransaction(Object resource, boolean commit) {
        if (resource instanceof Session) {
            Session session = (Session) resource;
            if (session.getTransaction().isActive()) {
                if (commit) {
                    session.getTransaction().commit();
                } else {
                    session.getTransaction().rollback();
                }
            }
        } else if (resource instanceof EntityManager) {
            EntityManager entityManager = (EntityManager) resource;
            if (entityManager.getTransaction().isActive()) {
                if (commit) {
                    entityManager.getTransaction().commit();
                } else {
                    entityManager.getTransaction().rollback();
                }
            }
        } else if (resource instanceof StatelessSession) {
            StatelessSession statelessSession = (StatelessSession) resource;
            if (statelessSession.getTransaction().isActive()) {
                if (commit) {
                    statelessSession.getTransaction().commit();
                } else {
                    statelessSession.getTransaction().rollback();
                }
            }
        }
    }

    /** Closes {@code resource} if it is one of the supported session types; otherwise does nothing. */
    private static void close(Object resource) {
        if (resource instanceof Session) {
            ((Session) resource).close();
        } else if (resource instanceof EntityManager) {
            ((EntityManager) resource).close();
        } else if (resource instanceof StatelessSession) {
            ((StatelessSession) resource).close();
        }
    }
}
